{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook we explore standard image classification on MNIST and CIFAR10 with convolutional Neural ODE variants (the cells below run on MNIST):\n",
    "\n",
    "* Depth-invariant neural ODE\n",
    "* Galerkin neural ODE (GalNODE)\n",
    "\n",
    "In the following notebooks we'll explore `augmentation` strategies that can be easily applied to the models below with the flexible `torchdyn` API. Here, the depth-invariant model uses simple `0-augmentation` (the ANODE model), while the Galerkin model uses `IL-augmentation`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys ; sys.path.append('../')\n",
    "from torchdyn.models import *; from torchdyn import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Stefano\\anaconda3\\lib\\site-packages\\pytorch_lightning\\utilities\\distributed.py:25: UserWarning: Unsupported `ReduceOp` for distributed computing.\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "from pytorch_lightning.metrics.functional import accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data configuration.\n",
    "batch_size = 128\n",
    "size = 28\n",
    "path_to_data = '../data/mnist_data'\n",
    "\n",
    "# Resize every image to `size` and convert it to a tensor.\n",
    "all_transforms = transforms.Compose([transforms.Resize(size), transforms.ToTensor()])\n",
    "\n",
    "# Only the training split requests a download into `path_to_data`.\n",
    "train_data = datasets.MNIST(root=path_to_data, train=True, download=True, transform=all_transforms)\n",
    "test_data = datasets.MNIST(root=path_to_data, train=False, transform=all_transforms)\n",
    "\n",
    "# Shuffle the training batches only; test batches keep the dataset order.\n",
    "trainloader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)\n",
    "testloader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The **Learner** is then defined as:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Learner(pl.LightningModule):\n",
    "    \"\"\"Lightning module that trains and tests a classifier with cross-entropy loss.\n",
    "\n",
    "    `model` must be indexable and expose an `nfe` attribute on its element at\n",
    "    index 1 (the `NeuralDE` layer of the models built in this notebook), because\n",
    "    `training_step` reads and resets that counter. Batches are served by the\n",
    "    notebook-level `trainloader` and `testloader`.\n",
    "    \"\"\"\n",
    "    def __init__(self, model:nn.Module):\n",
    "        super().__init__()\n",
    "        self.lr = 1e-3      # learning rate handed to AdamW\n",
    "        self.model = model\n",
    "        self.iters = 0.     # training steps seen so far, used for epoch progress\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.model(x)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Run one training step and report loss, accuracy and NFE.\"\"\"\n",
    "        self.iters += 1.\n",
    "        x, y = batch\n",
    "        # Explicit move kept from the original; presumably redundant because\n",
    "        # Lightning already places batches on the training device - TODO confirm.\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        y_hat = self.model(x)\n",
    "        loss = nn.functional.cross_entropy(y_hat, y)\n",
    "        # Fractional number of epochs completed so far.\n",
    "        epoch_progress = self.iters / self.loader_len\n",
    "        acc = accuracy(y_hat, y)\n",
    "        # Read and reset the counter on the wrapped model itself rather than on\n",
    "        # the notebook-level `model`, which may have been rebound since.\n",
    "        nfe = self.model[1].nfe\n",
    "        self.model[1].nfe = 0\n",
    "        tqdm_dict = {'train_loss': loss, 'accuracy': acc, 'NFE': nfe}\n",
    "        logs = {'train_loss': loss, 'epoch': epoch_progress}\n",
    "        return {'loss': loss, 'progress_bar': tqdm_dict, 'log': logs}\n",
    "\n",
    "    def test_step(self, batch, batch_nb):\n",
    "        \"\"\"Compute loss and accuracy for one test batch.\"\"\"\n",
    "        x, y = batch\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        y_hat = self(x)\n",
    "        acc = accuracy(y_hat, y)\n",
    "        return {'test_loss': nn.functional.cross_entropy(y_hat, y), 'test_accuracy': acc}\n",
    "\n",
    "    def test_epoch_end(self, outputs):\n",
    "        \"\"\"Aggregate the per-batch test results.\n",
    "\n",
    "        Both averages are unweighted means over batches, so a smaller final\n",
    "        batch counts as much as a full one.\n",
    "        \"\"\"\n",
    "        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()\n",
    "        avg_acc = torch.stack([x['test_accuracy'] for x in outputs]).mean()\n",
    "        logs = {'test_loss': avg_loss}\n",
    "        return {'avg_test_loss': avg_loss, 'avg_test_accuracy': avg_acc,\n",
    "                'log': logs, 'progress_bar': logs}\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Return AdamW and a ReduceLROnPlateau scheduler monitoring 'loss'.\"\"\"\n",
    "        opt = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=5e-5)\n",
    "        sched = {'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(opt),\n",
    "                 'monitor': 'loss', \n",
    "                 'interval': 'step',\n",
    "                 'frequency': 10  }\n",
    "        return [opt], [sched]\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        # Cache the number of training batches for the epoch-progress computation.\n",
    "        self.loader_len = len(trainloader)\n",
    "        return trainloader\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        # Cache the number of test batches (taken from the test loader itself).\n",
    "        self.test_loader_len = len(testloader)\n",
    "        return testloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Depth-Invariant Conv Neural ODE "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vector field of the ODE: a 3x3 convolution over 11 channels followed by tanh.\n",
    "func = nn.Sequential(\n",
    "    nn.Conv2d(in_channels=11, out_channels=11, kernel_size=3, padding=1),\n",
    "    nn.Tanh(),\n",
    ").to(device)\n",
    "\n",
    "# Integrated with RK4 over ten evenly spaced points in [0, 1].\n",
    "neuralDE = NeuralDE(\n",
    "    func,\n",
    "    solver='rk4',\n",
    "    sensitivity='autograd',\n",
    "    s_span=torch.linspace(0, 1, 10),\n",
    ").to(device)\n",
    "\n",
    "# Augmenter -> NeuralDE -> conv back to one channel -> linear read-out over 10 classes.\n",
    "# Presumably the augmenter adds 10 channels to the single MNIST channel (11 total) - TODO confirm.\n",
    "model = nn.Sequential(\n",
    "    Augmenter(augment_dims=10),\n",
    "    neuralDE,\n",
    "    nn.Conv2d(in_channels=11, out_channels=1, kernel_size=3, padding=1),\n",
    "    nn.Flatten(),\n",
    "    nn.Linear(in_features=28*28, out_features=10),\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name  | Type       | Params\n",
      "-------------------------------------\n",
      "0 | model | Sequential | 9 K   \n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cfd66a2a5a604b41a6e2fd5b92145d1b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn = Learner(model)\n",
    "# Train on one GPU when CUDA is available and on the CPU otherwise, so the\n",
    "# cell also runs on machines without a GPU.\n",
    "trainer = pl.Trainer(max_epochs=3,\n",
    "                     gpus=1 if torch.cuda.is_available() else None,\n",
    "                     progress_bar_refresh_rate=1,\n",
    "                     )\n",
    "\n",
    "trainer.fit(learn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3 epochs are not enough. Feel free to keep training and to experiment with all kinds of scheduling and optimization tricks :)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Galerkin Data-Controlled Conv Neural ODE (IL-Augmentation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vector field: two Galerkin convolutions with Fourier expansions (5 harmonics each).\n",
    "# Each is preceded by DataControl and DepthCat, which presumably append the input data\n",
    "# and the depth variable to the state - TODO confirm against torchdyn.\n",
    "func = nn.Sequential(\n",
    "    DataControl(),\n",
    "    DepthCat(1),\n",
    "    GalConv2d(10+10, 12, 3, padding=1, expfunc=FourierExpansion, n_harmonics=5),\n",
    "    nn.Softplus(),\n",
    "    DataControl(),\n",
    "    DepthCat(1),\n",
    "    GalConv2d(22, 10, 3, padding=1, expfunc=FourierExpansion, n_harmonics=5),\n",
    "    nn.Tanh(),\n",
    ")\n",
    "\n",
    "# Integrated with dopri5 between s=0 and s=1, using adjoint sensitivities.\n",
    "neuralDE = NeuralDE(\n",
    "    func,\n",
    "    solver='dopri5',\n",
    "    sensitivity='adjoint',\n",
    "    s_span=torch.linspace(0, 1, 2),\n",
    ").to(device)\n",
    "\n",
    "# Augmenter with a learned 1->9 channel conv, then NeuralDE, conv to one channel, linear read-out.\n",
    "model = nn.Sequential(\n",
    "    Augmenter(augment_idx=1, augment_func=nn.Conv2d(in_channels=1, out_channels=9, kernel_size=3, padding=1)),\n",
    "    neuralDE,\n",
    "    nn.Conv2d(in_channels=10, out_channels=1, kernel_size=3, padding=1),\n",
    "    nn.Flatten(),\n",
    "    nn.Linear(in_features=28*28, out_features=10),\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name  | Type       | Params\n",
      "-------------------------------------\n",
      "0 | model | Sequential | 49 K  \n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0e8f6943122c4163b035e33a668edd1a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn = Learner(model)\n",
    "# Train on one GPU when CUDA is available and on the CPU otherwise, so the\n",
    "# cell also runs on machines without a GPU.\n",
    "trainer = pl.Trainer(max_epochs=3,\n",
    "                     gpus=1 if torch.cuda.is_available() else None,\n",
    "                     progress_bar_refresh_rate=1,\n",
    "                     )\n",
    "\n",
    "trainer.fit(learn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3 epochs are not enough. Feel free to keep training and to experiment with all kinds of scheduling and optimization tricks :)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
